Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) #10980

Open
wants to merge 50 commits into
base: main
Choose a base branch
from

Conversation

afeldman-nm
Copy link
Contributor

@afeldman-nm afeldman-nm commented Dec 7, 2024

This PR adds support for parallel sampling to v1 AsyncLLM and LLMEngine.

Parallel sampling is implemented outside the engine. This does not impact the vLLM v0 parallel sampling implementation.

Design doc: https://docs.google.com/document/d/1_fvbHVCfexkPAj2Vx0Q0gNvE-b53WwtrLrD6NldgBFU/edit?usp=sharing (message me if you need access permissions)

A request with n>1 will spawn n requests with n=1 and aggregate their outputs in accordance with the output_kind. If prefix caching is enabled, an initial warmup request with max_tokens=1 will be sent to the engine to fill the prefix cache. Update: The vLLM v1 engine can exploit APC when a prompt repeats within a batch, even if that prompt was not seen in a previous batch. Therefore, no warmup request is required.

The abstractions are cleanest for async v1 engine because AsyncLLM presents a generate() method that handles adding requests and running the engine; parallel sampling is implemented by writing a wrapper which invokes this method n times against parallel sampling child requests.

v1 LLMEngine presents add_request() and step() methods, so parallel sampling is implemented by writing an add_request() wrapper which branches parallel sampling requests into n add_request() calls for child requests, and then having step() aggregate child request outputs into a parent request output.

FIX #13419
FIX #13420

Copy link

github-actions bot commented Dec 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Dec 7, 2024
Copy link

mergify bot commented Dec 13, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @afeldman-nm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@FrederickVu
Copy link

It seems like this PR is implementing ideas similar to those implemented in PR #9302 for the V0 engine. That PR created some issues that were addressed in PR #11898 and which may exist in the proposed V1 code.

In particular, the proposed code currently does not properly handle the case when a seed value is provided for the parent request; the seed value is duplicated in child requests, leading to identical outputs in the child requests. The fix in #11898 was simply to move the copying of the SamplingParams object inside a for-loop and to increment the seed value of the parent request.

Additionally, the proposed code for the V1 engine defines get_num_unfinished_requests() in a way that is currently incompatible with the same function for the V0 engine (though it's really the V0 engine's code which is in the wrong, I am choosing to mention this here as it seems relevant and the PR is still open). In the V0 engine, that function actually counts the number of SequenceGroup objects in the engine as opposed to the number of requests. This led to bugs mentioned in Issues #10949 and #11519, and an attempt to fix them is in PR #12428. However, the proposed fix is not compatible with the proposed V1 code. I will comment on that PR as well so that maintainers can adapt the V0 code to align with the proposed V1 code.

Copy link

mergify bot commented Feb 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @afeldman-nm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 5, 2025
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
@mergify mergify bot added documentation Improvements or additions to documentation ci/build labels Feb 10, 2025
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
@afeldman-nm afeldman-nm marked this pull request as ready for review February 11, 2025 15:21
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
@m-harmonic
Copy link

@m-harmonic

Thanks for working on this. I haven't had a chance to look into the specifics of the new implementation but also wanted to ask about n>1 behavior as this PR (#9302) reduced our throughput by about 3x. It seems that sequence forking was extremely important in the v0 engine. Is an equivalent strategy going to be used in v1? We're now waiting for n>1 support in the v1 engine before we can upgrade past 0.6.3. Thanks

Thanks for the feedback. I think this is because you didn't add --enable-prefix-caching. Starting with the PR you mentioned, vLLM relies on prefix caching to deduplicate the prompts in parallel sampling. As prefix caching is not enabled by default in V0, you may see slowdown if you didn't add --enable-prefix-caching. With V1, prefix caching is enabled by default, so you should be able to enjoy the speedup without any extra config.

Hi thanks for the reply. We did pass enable_prefix_caching=True in the engine arguments, which did not affect the throughput. Is there something else we should be doing?
Edit: I got a notification for a comment by @ afeldman-nm but oddly enough it's not showing up now. Per your question, we see the decreased throughput in VLLM's log stats. The number of generated tokens / sec falls from roughly 7k/sec to 2k/sec. In contrast, the prompt throughput rises by about 4-5x. This is for a value of n between 50 and 100. We did have prefix caching enabled in the engine args.

Hi @m-harmonic , sorry I deleted my response because I thought @WoosukKwon 's suggestion to use prefix caching was more relevant, although it appears that that did not resolve the issue. Although, based on the metrics you shared, it sounds like prefix caching is working but the decode throughput is the issue.

Thanks @afeldman-nm, appreciate the reply. One thing I'm wondering is whether matching n>1 throughput performance pre-0.6.4 is a priority for this change or follow-up work. It seems like fundamentally this new approach confers lower throughput. Relatedly, is there an architectural reason why this new approach is necessary in v1 or is it possible forking could be added later on if it makes a big difference?

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
@afeldman-nm afeldman-nm changed the title [V1] V1 engine implements parallel sampling, 1/2: AsyncLLM support [V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) Feb 18, 2025
@afeldman-nm
Copy link
Contributor Author

Merged #13421 (LLMEngine support for parallel sampling) into this PR.

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @afeldman-nm for the great work.

I am still reviewing the llm_engine.py changes but wanted to post the comments I have so far. I feel like there may still be room to simplify/condense things a bit more, and possibly move more logic out of the main classes.

Comment on lines +278 to +279
async for out in merge_async_iterators(*gens):
yield out[1] # out[0] is index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
async for out in merge_async_iterators(*gens):
yield out[1] # out[0] is index
async for _, out in merge_async_iterators(*gens):
yield out

@@ -241,6 +244,56 @@ async def generate(
await self.abort(request_id)
raise

async def _generate_parallel_sampling(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether it might be cleaner to also move this method into ParallelSamplingRequestManager, but it takes the base generate` method as an additional arg.

) -> AsyncGenerator[RequestOutput, None]:
"""Generate completions for parallel sampling requests."""
req_mgr = ParallelSamplingRequestManager(request_id, sampling_params)
n = req_mgr.n
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nti: no need for this variable

# Aggregate generators for n child requests
gens: List[AsyncGenerator[RequestOutput, None]] = []
for idx in range(n):
c_sampling_params = req_mgr.get_child_sampling_params(idx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about a single method that returns both the child sampling params and request id?

Child `sampling_params` instance.
"""
seed = self.sampling_params.seed
if seed is None and self.cached_child_sampling_params:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if seed is None and self.cached_child_sampling_params:
if self.cached_child_sampling_params:

# Note: will be sorted by index later
self.request_output.outputs.append(new_completion)

def _get_parent_request_output(self) -> RequestOutput:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggested rename:

Suggested change
def _get_parent_request_output(self) -> RequestOutput:
def _get_final_request_output(self) -> RequestOutput:

from vllm.sampling_params import RequestOutputKind, SamplingParams


class ParallelSamplingRequestManager:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT about renaming this to

Suggested change
class ParallelSamplingRequestManager:
class ParallelSamplingRequest:

Comment on lines +262 to +266
def _num_parallel_sampling_requests(self) -> int:
return len(self.parallel_parent_reqs)

def _num_parallel_sampling_child_requests(self) -> int:
return len(self.parallel_child_reqs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO these helper methods are unnecessary (just adds more LOC)

Comment on lines +273 to +274
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else
self._do_reset_parallel_sampling)
Copy link
Member

@njhill njhill Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else
self._do_reset_parallel_sampling)
self._do_reset_parallel_sampling |= num_parallel_reqs > 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, looking again, this would be clearer:

Suggested change
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else
self._do_reset_parallel_sampling)
if self.parallel_parent_requests:
self._do_reset_parallel_sampling = True

if num_parallel_reqs > 0 and len(request_outputs) > 0:
# Process parallel sampling child request outputs
return self._aggregate_parallel_sampling_outputs(request_outputs)
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: redundant else

Comment on lines +273 to +274
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else
self._do_reset_parallel_sampling)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, looking again, this would be clearer:

Suggested change
self._do_reset_parallel_sampling = (True if num_parallel_reqs > 0 else
self._do_reset_parallel_sampling)
if self.parallel_parent_requests:
self._do_reset_parallel_sampling = True

prompt_adapter_request=prompt_adapter_request,
priority=priority)

def _add_request_parallel_sampling(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar thought to the async case, perhaps this could be moved into the other class? (along with the state ... could have a subclass with that just for the sync case)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build documentation Improvements or additions to documentation frontend v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature]: V1 parallel sampling support in LLMEngine [Feature]: V1 parallel sampling support in AsyncLLM
6 participants